# -*-coding:utf-8-*-
# date:2021-04-15
# author: likecy
# function : inference

import os
import argparse
import torch
import torch.nn as nn
from data_iter.datasets import letterbox
import numpy as np

import time
import datetime
import os
import math
from datetime import datetime
import cv2
import torch.nn.functional as F
import xml.etree.cElementTree as ET
from models.build_model import *
from utils.model_utils import *
from sklearn.metrics import precision_score, recall_score, f1_score

# Registry mapping the --model CLI string to its backbone constructor.
# The constructors come from the star import of models.build_model above.
# NOTE(review): every 'efficientnet-*' key maps to the same Efficientnet
# class — the concrete variant is selected later by passing
# model_name=ops.model (see the build step in __main__).
MODEL_NAMES = {
    'alexnet': Alexnet,
    'googlenet': Googlenet,
    'resnet18': Resnet18,
    'resnet50': Resnet50,
    'resnet101': Resnet101,
    'resnet152': Resnet152,
    'resnext101_32x8d': Resnext101_32x8d,
    'resnext101_32x16d': Resnext101_32x16d,
    'resnext101_32x48d': Resnext101_32x48d,
    'resnext101_32x32d': Resnext101_32x32d,
    'densenet121': Densenet121,
    'densenet169': Densenet169,
    'moblienetv2': Mobilenetv2,  # NOTE(review): key is misspelled ('moblie') — kept as-is, callers may rely on it
    'efficientnet-b7': Efficientnet,
    'efficientnet-b0': Efficientnet,
    'efficientnet-b8': Efficientnet,
    'squeezenet1_0': Squeezenet1_0,
    'squeezenet1_1': Squeezenet1_1,
    'shufflenet_v2_x0_5': Shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': Shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': Shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': Shufflenet_v2_x2_0
}


def get_xml_msg(path):
    """Parse a Pascal-VOC style XML annotation file.

    Args:
        path: filesystem path of the XML annotation file.

    Returns:
        A list with one ``(name, (xmin, ymin, xmax, ymax))`` tuple per
        ``<object>`` element; coordinates are ``np.float32`` values.
    """
    list_x = []
    tree = ET.parse(path)  # parse the xml file
    root = tree.getroot()
    for obj in root.findall('object'):
        name = obj.find('name').text
        bndbox = obj.find('bndbox')
        # original also built an unused int `bbox` tuple here — removed
        xmin = np.float32(bndbox.find('xmin').text)
        ymin = np.float32(bndbox.find('ymin').text)
        xmax = np.float32(bndbox.find('xmax').text)
        ymax = np.float32(bndbox.find('ymax').text)
        list_x.append((name, (xmin, ymin, xmax, ymax)))
    return list_x


def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string — including
    the literal 'False' — as True; map the common spellings explicitly.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')


def _list_class_dirs(root):
    """Return class sub-directories of *root*, sorted by their numeric
    'N-...' prefix.

    Filters out macOS '.DS_Store' entries. The original code attempted
    this with ``del i`` inside a loop over the list, which only deleted
    the loop variable and left the entry in place, so the later
    ``int(x.split('-')[0])`` sort key could crash.
    """
    entries = [d for d in os.listdir(root) if d != ".DS_Store"]
    return sorted(entries, key=lambda x: int(x.split('-')[0]))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description=' Project Classification top1 Test')
    parser.add_argument('--test_model', type=str, default='./weights/squeezenet1_1-size-224_epoch-50.pth',
                        help='test_model')  # checkpoint path
    parser.add_argument('--model', type=str, default='squeezenet1_1',
                        help='model :参考models/build_model 中的模型 MODEL_NAMES')  # model type
    parser.add_argument('--num_classes', type=int, default=2,
                        help='num_classes')  # number of classes
    parser.add_argument('--GPUS', type=str, default='0',
                        help='GPUS')  # GPU selection
    parser.add_argument('--test_path', type=str, default='/Users/chenyun/xjdlab/CT数据识别/tests/',
                        help='test_path')  # test-set directory
    parser.add_argument('--img_size', type=tuple, default=(224, 224),
                        help='img_size')  # model input size (h, w)
    # type=_str2bool fixes the argparse type=bool pitfall: '--vis False'
    # used to evaluate truthy. 'True'/'1'/'yes' still work as before.
    parser.add_argument('--fix_res', type=_str2bool, default=False,
                        help='fix_resolution')  # keep aspect ratio via letterbox
    parser.add_argument('--have_label_file', type=_str2bool, default=False,
                        help='have_label_file')  # whether label files exist
    parser.add_argument('--vis', type=_str2bool, default=False,
                        help='vis')  # visualize images while predicting

    print('\n/******************* {} ******************/\n'.format(parser.description))
    # --------------------------------------------------------------------------
    ops = parser.parse_args()  # parse the CLI arguments
    # --------------------------------------------------------------------------
    print('----------------------------------')

    # parse_args() returns a Namespace; vars() converts it to a dict
    for key, value in vars(ops).items():
        print('{} : {}'.format(key, value))

    # ---------------------------------------------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

    # ---------------------------------------------------------------- build model
    print('use model : %s' % (ops.model))

    if not ops.model.startswith('efficientnet'):
        model_ = MODEL_NAMES[ops.model](num_classes=ops.num_classes)
    else:
        # Efficientnet is one class parameterized by the variant name
        model_ = MODEL_NAMES[ops.model](
            model_name=ops.model, num_classes=ops.num_classes)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)

    # load the checkpoint onto the chosen device
    chkpt = torch.load(ops.test_model, map_location=device)
    model_.load_state_dict(chkpt)
    print('load test model : {}'.format(ops.test_model))

    print(model_)  # print the model structure
    model_.eval()  # inference mode
    # ----------------------------------------------------------------
    y_true = []
    y_pred = []
    dict_r = {}       # class dir -> number of correct predictions
    dict_p = {}       # class dir -> number of times the class was predicted
    dict_static = {}  # class index -> class dir name

    # one sorted, filtered listing reused everywhere; the original
    # re-listed the directory (unfiltered) before the inference loop
    class_dirs = _list_class_dirs(ops.test_path)
    for idx, doc in enumerate(class_dirs):
        # original guarded with `doc not in dict_static.keys()`, but the
        # keys are integer indices so that test was always True
        dict_static[idx] = doc
        dict_r[doc] = 0
        dict_p[doc] = 0
    # ---------------------------------------------------------------- predict
    print(dict_static)
    with torch.no_grad():
        for idx, doc in enumerate(class_dirs):
            gt_label = idx
            for file in os.listdir(os.path.join(ops.test_path, doc)):
                if ".jpg" not in file:
                    continue
                print('------>>> {} - gt_label : {}'.format(file, gt_label))

                img = cv2.imread(os.path.join(ops.test_path, doc, file))
                if img is None:
                    # unreadable/corrupt image: skip instead of crashing
                    continue

                if ops.fix_res:
                    # letterbox keeps the aspect ratio, padding with gray
                    img_ = letterbox(
                        img, size_=ops.img_size[0], mean_rgb=(128, 128, 128))
                else:
                    img_ = cv2.resize(
                        img, (ops.img_size[1], ops.img_size[0]), interpolation=cv2.INTER_CUBIC)
                if ops.vis:
                    cv2.namedWindow('image', 0)
                    cv2.imshow('image', img_)
                    cv2.waitKey(1)

                # normalize to roughly [-0.5, 0.5) and HWC -> NCHW
                img_ = img_.astype(np.float32)
                img_ = (img_ - 128.) / 256.
                img_ = img_.transpose(2, 0, 1)
                img_ = torch.from_numpy(img_).unsqueeze_(0)

                if use_cuda:
                    img_ = img_.cuda()  # (bs, 3, h, w)

                pre_ = model_(img_.float())
                output = F.softmax(pre_, dim=1)[0].cpu().numpy()

                max_index = int(np.argmax(output))
                score_ = output[max_index]
                y_true.append(gt_label)
                y_pred.append(max_index)
                print('gt {} - {} -- pre {}     --->>>    confidence {}'.format(doc,
                      gt_label, max_index, score_))
                dict_p[dict_static[max_index]] += 1
                if gt_label == max_index:
                    dict_r[doc] += 1

    cv2.destroyAllWindows()
    # Top-1 metrics.
    print('\n-----------------------------------------------\n')
    # the sklearn default average='binary' raises for num_classes > 2;
    # fall back to macro averaging in the multi-class case
    average = 'binary' if ops.num_classes == 2 else 'macro'
    p = precision_score(y_true, y_pred, average=average)
    r = recall_score(y_true, y_pred, average=average)
    f1 = f1_score(y_true, y_pred, average=average)
    print("精确度：", p)
    print("召回率:", r)
    print("F1值", f1)

    os.makedirs('./log', exist_ok=True)  # open() fails if ./log is missing
    with open('./log/' + ops.model + '_interface_result.txt', "w",
              encoding='utf-8') as fs:
        fs.write('\n-----------------------------------------------\n')
        fs.write("eval form :" + ops.test_model)
        fs.write('\n-----------------------------------------------\n')
        fs.write("\n 精确度:{}".format(p))
        fs.write("\n 召回率:{}".format(r))
        fs.write("\n F1_source值:{}".format(f1))

        acc_list = []
        for doc in class_dirs:
            fm = dict_p[doc]  # times this class was predicted
            fz = dict_r[doc]  # correct predictions for this class
            # guard: a class that was never predicted used to raise
            # ZeroDivisionError outside the try block
            acc = fz / fm if fm else 0.0
            acc_list.append(acc)
            fs.write('\n {}: {}'.format(doc, acc))
            print('{}: {}'.format(doc, acc))

        fs.write("\n MAP : {}".format(np.mean(acc_list)))
        fs.write('\n-----------------------------------------------\n')
        print("\n MAP : {}".format(np.mean(acc_list)))

    print('\nwell done ')